<!DOCTYPE html>
<html lang="zh-CN">
<head>
  <meta charset="utf-8">
  <title>PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测 | Jin Tian</title>
  <!-- No maximum-scale / user-scalable restriction: readers must be able to pinch-zoom (WCAG 1.4.4). -->
  <meta name="viewport" content="width=device-width, initial-scale=1">

  <!-- Search and social-sharing metadata (Open Graph + Twitter card). -->
  <meta name="description" content="本文介绍PaddlePaddle, TensorFlow, MXNet, Caffe2, PyTorch五大深度学习框架2017-10最新评测">
  <meta property="og:type" content="article">
  <meta property="og:title" content="PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测">
  <!-- NOTE(review): the generator emitted Hexo's placeholder domain (yoursite.com) here. The article links the
       author's own posts on jinfagang.github.io, so that host is assumed; confirm the site `url` setting. -->
  <meta property="og:url" content="https://jinfagang.github.io/2017/10/13/PaddlePaddle-TensorFlow等五大深度学习框架最新评测/index.html">
  <meta property="og:site_name" content="Jin Tian">
  <meta property="og:description" content="本文介绍PaddlePaddle, TensorFlow, MXNet, Caffe2, PyTorch五大深度学习框架2017-10最新评测">
  <meta property="og:locale" content="zh-CN">
  <meta property="og:image" content="https://i.loli.net/2017/10/12/59df2c7db0062.jpg">
  <meta property="og:image" content="https://ooo.0o0.ooo/2017/10/12/59df382c2c753.jpeg">
  <meta property="og:updated_time" content="2017-10-24T12:54:53.000Z">
  <meta name="twitter:card" content="summary">
  <meta name="twitter:title" content="PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测">
  <meta name="twitter:description" content="本文介绍PaddlePaddle, TensorFlow, MXNet, Caffe2, PyTorch五大深度学习框架2017-10最新评测">
  <meta name="twitter:image" content="https://i.loli.net/2017/10/12/59df2c7db0062.jpg">

  <link rel="alternate" href="/atom.xml" title="Jin Tian" type="application/atom+xml">
  <link rel="icon" href="/css/images/mylogo.jpg">
  <link rel="apple-touch-icon" href="/css/images/mylogo.jpg">

  <!-- Web fonts and icon font, requested over explicit https rather than protocol-relative URLs. -->
  <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Source+Code+Pro">
  <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Open+Sans|Montserrat:700">
  <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:400,300,300italic,400italic">
  <link rel="stylesheet" href="https://cdn.bootcss.com/font-awesome/4.6.3/css/font-awesome.min.css">
  <!-- futura-pt faces served by Typekit. The `fvd` query token encodes style + weight
       (n4 = normal 400, n5 = normal 500, n7 = normal 700, i4 = italic 400). Each @font-face
       declares the weight matching its file, so the n5 and n7 faces no longer collide on 500,
       and the n4 face uses a valid descriptor value (`lighter` is not allowed in @font-face). -->
  <style>
    @font-face{font-family:futura-pt;src:url(https://use.typekit.net/af/9749f0/00000000000000000001008f/27/l?subset_id=2&fvd=n5) format("woff2");font-weight:500;font-style:normal;}
    @font-face{font-family:futura-pt;src:url(https://use.typekit.net/af/90cf9f/000000000000000000010091/27/l?subset_id=2&fvd=n7) format("woff2");font-weight:700;font-style:normal;}
    @font-face{font-family:futura-pt;src:url(https://use.typekit.net/af/8a5494/000000000000000000013365/27/l?subset_id=2&fvd=n4) format("woff2");font-weight:400;font-style:normal;}
    @font-face{font-family:futura-pt;src:url(https://use.typekit.net/af/d337d8/000000000000000000010095/27/l?subset_id=2&fvd=i4) format("woff2");font-weight:400;font-style:italic;}
  </style>
  <link rel="stylesheet" href="/css/style.css">

  <!-- Kept render-blocking (no defer): presumably later synchronous scripts such as /js/insight.js
       rely on jQuery being loaded already. Confirm before adding defer. -->
  <script src="/js/jquery-3.1.1.min.js"></script>
  <script src="/js/bootstrap.js"></script>

  <!-- Bootstrap core CSS -->
  <link rel="stylesheet" href="/css/bootstrap.css">

  <link rel="stylesheet" href="/css/dialog.css">
  <link rel="stylesheet" href="/css/header-post.css">
  <link rel="stylesheet" href="/css/vdonate.css">
</head>



  <body data-spy="scroll" data-target="#toc" data-offset="50">


  
  <div id="container">
    <div id="wrap">
      
        <header>
          <!-- Site navigation bar: logo (opens the #myModal dialog), primary links, and the Insight search widget. -->
          <div id="allheader" class="navbar navbar-default navbar-static-top" role="navigation">
            <div class="navbar-inner">
              <div class="container">
                <button type="button" class="navbar-toggle" data-toggle="collapse" data-target=".navbar-collapse" aria-expanded="false">
                  <span class="sr-only">Toggle navigation</span>
                  <span class="icon-bar"></span>
                  <span class="icon-bar"></span>
                  <span class="icon-bar"></span>
                </button>

                <!-- width/height must be plain integers (CSS pixels), not "124px". -->
                <a class="brand" style="border-width: 0px; margin-top: 0px;" href="#" data-toggle="modal" data-target="#myModal">
                  <img width="124" height="124" alt="Jin Tian" src="/css/images/mylogo.jpg">
                </a>

                <div class="navbar-collapse collapse">
                  <ul class="hnav navbar-nav">
                    <li> <a class="main-nav-link" href="/">首页</a> </li>
                    <li> <a class="main-nav-link" href="/archives">归档</a> </li>
                    <li> <a class="main-nav-link" href="/categories">分类</a> </li>
                    <li> <a class="main-nav-link" href="/tags">标签</a> </li>
                    <li> <a class="main-nav-link" href="/about">关于</a> </li>
                    <li> <a class="main-nav-link" href="http://luoli-luoli.com/chat">chat</a> </li>
                    <li>
                      <div id="search-form-wrap">
                        <!-- Inline search box; the controls carry aria-labels because they have no visible label text. -->
                        <form class="search-form">
                          <input type="text" class="ins-search-input search-form-input" aria-label="搜索">
                          <button type="submit" class="search-form-submit" aria-label="搜索"></button>
                        </form>
                        <!-- Overlay markup consumed by /js/insight.js (class names must stay as-is). -->
                        <div class="ins-search">
                          <div class="ins-search-mask"></div>
                          <div class="ins-search-container">
                            <div class="ins-input-wrapper">
                              <input type="text" class="ins-search-input" placeholder="请输入关键词..." aria-label="搜索">
                              <span class="ins-close ins-selectable"><i class="fa fa-times-circle" aria-hidden="true"></i></span>
                            </div>
                            <div class="ins-section-wrapper">
                              <div class="ins-section-container"></div>
                            </div>
                          </div>
                        </div>
                        <!-- Exposes window.INSIGHT_CONFIG (UI labels + content index URL) before insight.js loads. -->
                        <script>
                        (function (window) {
                            var INSIGHT_CONFIG = {
                                TRANSLATION: {
                                    POSTS: '文章',
                                    PAGES: '页面',
                                    CATEGORIES: '分类',
                                    TAGS: '标签',
                                    UNTITLED: '(无标题)',
                                },
                                ROOT_URL: '/',
                                CONTENT_URL: '/content.json',
                            };
                            window.INSIGHT_CONFIG = INSIGHT_CONFIG;
                        })(window);
                        </script>
                        <script src="/js/insight.js"></script>
                      </div>
                    </li>
                  </ul>
                </div>
              </div>
            </div>
          </div>
        </header>



      
            
      <div id="content" class="outer">
        
          <section id="main" style="float:none;"><article id="post-PaddlePaddle-TensorFlow等五大深度学习框架最新评测" style="width: 75%; float:left;" class="article article-type-post" itemscope itemprop="blogPost" >
  <div id="articleInner" class="article-inner">
    
    
      <header class="article-header">
        
  
    <h1 class="thumb article-title" itemprop="name">
      PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测
    </h1>
  

      </header>
    
    <div class="article-meta">
      
	<a href="/2017/10/13/PaddlePaddle-TensorFlow等五大深度学习框架最新评测/" class="article-date">
	  <time datetime="2017-10-13T06:55:55.000Z" itemprop="datePublished">2017-10-13</time>
	</a>

      
    <a class="article-category-link" href="/categories/默认分类/">默认分类</a>

      
	<a class="article-views">
	<span id="busuanzi_container_page_pv">
		阅读量<span id="busuanzi_value_page_pv"></span>
	</span>
	</a>

    </div>
    <div class="article-entry" itemprop="articleBody">
      
        <p>本文介绍PaddlePaddle, TensorFlow, MXNet, Caffe2, PyTorch五大深度学习框架2017-10最新评测<br><a id="more"></a></p>
<h1 id="PaddlePaddle-TensorFlow-MXNet-Caffe2-PyTorch五大深度学习框架2017-10最新评测"><a href="#PaddlePaddle-TensorFlow-MXNet-Caffe2-PyTorch五大深度学习框架2017-10最新评测" class="headerlink" title="PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测"></a>PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测</h1><h2 id="前言"><a href="#前言" class="headerlink" title="前言"></a>前言</h2><p>本文将是2017下半年以来，最新也是最全的一个深度学习框架评测。这里的评测并不是简单的使用评测，我们将用这五个框架共同完成一个深度学习任务，从框架使用的易用性、训练的速度、数据预处理的繁琐程度，以及显存占用大小等几个方面来进行全方位的测评，除此之外，我们还将给出一个非常客观，非常全面的使用建议。最后提醒大家<strong>本篇文章不仅仅是一个评测，你甚至可以作为五大框架的入门教程</strong>。</p>
<h2 id="0-五大框架概览"><a href="#0-五大框架概览" class="headerlink" title="0. 五大框架概览"></a>0. 五大框架概览</h2><p>在评测之前，让我们先对这五大框架进行一个全方位的概览，以及他们目前所处的发展地位。首先在这五大框架中，很多人肯定会问，为什么没有Keras？为什么没有CNTK？在这里我说明一点，本篇文章<strong>偏向于工业化级别的应用评测</strong>，主要评测主流框架，当然不是说Keras和CNTK就不主流了，文章没有任何利益相关的东西，只不过是Keras本身就拥有多种框架作为后端，因此与它的后端框架对比也就没有任何意义，Keras毫无疑问是速度最慢的。而CNTK由于笔者对Windows无感因此也就没有在评测范围之内(CNTK也是一个优秀的框架，当然也跨平台，感兴趣者可以去踩踩坑)。</p>
<p>TensorFlow可以说是目前发展最活跃的框架，其GitHub仓库已经有<strong>72.3k</strong>个star，MXNet是<strong>11.5k</strong>，Caffe2是<strong>5.9k</strong>。当然Caffe2推出得稍晚一些，MXNet的官方GitHub repo也是后来才转入Apache的孵化项目。但是从GitHub受关注度来看，无疑TensorFlow和MXNet是更被看好的。</p>
<p>即使我不做这篇测评，很多人也知道这些框架目前为止有一些这样的评价：</p>
<ul>
<li>TensorFlow API比较繁杂，使用上手困难，乱七八糟的东西很多，但是生态丰富，很多深度学习模型都有TF的实现，有Google大佬加持；</li>
<li>MXNet 占用内存小，速度快，非常小巧玲珑，有着天生的开源基因，完全靠社区推动的框架；</li>
<li>Caffe2 是面向工业级应用的框架，但是推出较晚，而且主打Python2（excuse me? 2017年了还主打Python2？），我不由自主地黑一下，从安装部署角度来说用户体验不是非常友好；</li>
<li>PyTorch 是Facebook面向学术界推出的一个框架，使用非常简单，搭建神经网络就像Keras和MATLAB一样，但是我又不得不黑一下，每次还得判断一下是GPU还是CPU？（excuse me? 真的应了那句话，我踩过了tf的坑才知道tf的好）；</li>
<li>PaddlePaddle 是百度开源的一个框架，国内也有很多人用。我的感受是，它非常符合中国人的使用习惯，但是在API的实用性上还有待进一步加强。我曾经写过一篇入门PaddlePaddle的博客，不得不说，PaddlePaddle的中文文档写得非常清楚，上手比较简单，参见<a href="https://jinfagang.github.io/2017/10/10/paddlepaddle系列之三行代码从入门到精通/" target="_blank" rel="noopener">PaddlePaddle三行代码从入门到精通</a>；</li>
</ul>
<p>以上评价是以前的评价，夹杂着一丝个人使用感受，最后说一下他们各自目前的好的动向：</p>
<ul>
<li>TensorFlow models这个模型库更新非常快，以前的一些图片分类，目标检测，图片生成文字，生成对抗网络都有现成的深度学习应用的例子，包括现在更新的基于知识图谱的问答项目，神经网络编程机器人等项目，这些官方生态对于一个框架来说非常有用，这无疑是tf的一个长处</li>
<li>MXNet早在几个月前就推出了Gluon这个接口，说白了就是一个Keras，包装了一个更加方便使用的API，但是目前来说还只能实现一些简单的网络的构建，复杂的还是得用原生的API，这里有一个教程链接：<a href="https://github.com/mli/cvpr17.git" target="_blank" rel="noopener">Gluon资料</a>。除此之外，MXNet也有一个示例仓库，其中有一些有意思的项目，比如语音识别，但是感觉实现得非常不友好，代码几乎凌乱不堪；</li>
<li>Caffe2 相对于前面两者来说可以说非常弱了，没有丝毫亮点，说好的C++高速工业级框架呢？除了吹牛逼忽悠大众，能搞些有用的官方使用文档或者教程出来吗？不好多说什么。</li>
<li>PyTorch就一笔带过了，偏向于学术快速实现，要工业级应用，比如做个模型跑到服务器上或者安卓手机上或者嵌入式上应该搞不来；</li>
<li>PaddlePaddle 现在做的还不错，我强调一句，<strong>Paddle是唯一一个不配置任何第三方库，克隆直接make就能成功的框架</strong>,  被caffe编译虐过的人应该对此深有感触。</li>
</ul>
<p>说了这么多，相信大家对目前的框架有了一个大致的了解，那么接下来我们就用其中几个框架来完成分类图片这么一个任务吧，这里面将包含<strong>图片如何导入模型</strong>,  <strong>如何写网络</strong>， <strong>整个训练的Pipeline</strong>等内容。</p>
<p>我们此次评测的任务是图片分类，大家尝试任何一个框架只需要新建一个文件夹，比如<code>mxnet_classifier</code>, 把数据扔到 <code>data</code> 里即可，我们侧重评测数据预处理的复杂程度，和网络编写的复杂程度。</p>
<p>图片下载地址：<a href="http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar" target="_blank" rel="noopener">images.tar</a>，<a href="http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar" target="_blank" rel="noopener">annotation.tar</a>。下载之后放到 <code>data</code> 目录下，目录结构如下：</p>
<figure class="highlight css"><table><tr><td class="code"><pre><div class="line"><span class="selector-tag">paddle_test</span></div><div class="line">└── <span class="selector-tag">data</span></div><div class="line">    ├── <span class="selector-tag">annotation</span><span class="selector-class">.tar</span></div><div class="line">    └── <span class="selector-tag">images</span><span class="selector-class">.tar</span></div></pre></td></tr></table></figure>
<p>解压 images.tar 之后，Images 目录下的每一个文件夹就是一类狗（一个类别），其实分类任务我们只需要这部分数据就可以了。</p>
<h2 id="1-MXNet"><a href="#1-MXNet" class="headerlink" title="1. MXNet"></a>1. MXNet</h2><p>首先上场的是MXNet。建议大家看一下上面贴出的李沐大神写的Gluon资料（PPT），里面介绍了Gluon和其他框架的区别，以及MXNet在多GPU上训练的优势。</p>
<p>没有安装的安装一下：</p>
<figure class="highlight cmake"><table><tr><td class="code"><pre><div class="line">sudo pip3 <span class="keyword">install</span> mxnet</div><div class="line">sudo pip3 <span class="keyword">install</span> mxnet-cu80</div><div class="line">sudo pip3 <span class="keyword">install</span> mxnet-cu80mkl</div></pre></td></tr></table></figure>
<p>分别是CPU乞丐版，GPU土豪版，GPU加CPU加速至尊豪华版。安装完了你应该clone一下mxnet的源代码，从tools里面找到im2rec.py这个工具，我们做图片，不管是检测还是分割还是分类，都按照mxnet的逻辑把图片转成二进制的rec格式吧。</p>
<p>我们现在有了Images文件夹，用<code>im2rec.py</code>处理参数这样写：</p>
<figure class="highlight brainfuck"><table><tr><td class="code"><pre><div class="line"><span class="comment">python3</span> <span class="comment">im2rec</span><span class="string">.</span><span class="comment">py</span> <span class="comment">standford_dogs</span> <span class="comment">Images/</span> <span class="literal">-</span><span class="literal">-</span><span class="comment">list</span> <span class="comment">true</span>  <span class="literal">-</span><span class="literal">-</span><span class="comment">recursive</span> <span class="comment">true</span> <span class="literal">-</span><span class="literal">-</span><span class="comment">train</span><span class="literal">-</span><span class="comment">ratio</span> <span class="comment">0</span><span class="string">.</span><span class="comment">8</span> <span class="literal">-</span><span class="literal">-</span><span class="comment">test</span><span class="literal">-</span><span class="comment">ratio</span> <span class="comment">0</span><span class="string">.</span><span class="comment">2</span></div></pre></td></tr></table></figure>
<p>这一步会生成两个文件：</p>
<ul>
<li>standford_dogs_train.lst</li>
<li>standford_dogs_test.lst</li>
</ul>
<p><code>standford_dogs</code> 是前缀，<code>--list true</code> 表示生成列表文件，<code>--recursive true</code> 用于这种每一个文件夹代表一类的情况。最后在 <code>standford_dogs_train.lst</code> 里面的内容是这样的：</p>
<figure class="highlight dns"><table><tr><td class="code"><pre><div class="line"><span class="number">5008</span>	<span class="number">27.000000</span>	n<span class="number">02092339</span>-Weimaraner/n020<span class="number">92339_2885</span>.jpg</div><div class="line"><span class="number">5092</span>	<span class="number">27.000000</span>	n<span class="number">02092339</span>-Weimaraner/n020<span class="number">92339_6548</span>.jpg</div></pre></td></tr></table></figure>
<p>第一列是图片的索引（index），第二列是类别标签——之所以显示成 <code>27.000000</code> 这样的小数，是因为 im2rec 生成 lst 时把标签按浮点数格式写出；第三列是图片相对于根目录的路径。好了，有了这个lst文件，我们继续用im2rec来生成rec二进制数据吧，这一步非常简单，直接传入上面的前缀和Images这个图片根目录即可：</p>
<figure class="highlight stylus"><table><tr><td class="code"><pre><div class="line">python3 im2rec<span class="selector-class">.py</span> standford_dogs Images/</div></pre></td></tr></table></figure>
<p>mxnet会依次生成train和test的rec文件：</p>
<p><img src="https://i.loli.net/2017/10/12/59df2c7db0062.jpg" alt=""></p>
<p>OK, mxnet做数据集也不是非常的麻烦，这个过程如果满分五分的话我给4分，pytorch如果不考虑性能的话应该是最直接的，直接从文件夹导入，但是rec格式更快。生成之后总共有了2.8G的文件。</p>
<p>好了，数据准备好了，直接写一个网络开始训练啰？我要写一个VGG怎么办？我要看论文吗？我要从第一层开始看网络结构吗？我要换ResNet怎么办？要换Inception怎么办？没有关系！MXNet 官方example（<code>example/image-classification/symbols</code>）包含了大多数这些网络结构！！</p>
<figure class="highlight stylus"><table><tr><td class="code"><pre><div class="line">├── alexnet<span class="selector-class">.py</span></div><div class="line">├── googlenet<span class="selector-class">.py</span></div><div class="line">├── inception-bn<span class="selector-class">.py</span></div><div class="line">├── inception-resnet-v2<span class="selector-class">.py</span></div><div class="line">├── inception-v3<span class="selector-class">.py</span></div><div class="line">├── inception-v4<span class="selector-class">.py</span></div><div class="line">├── lenet<span class="selector-class">.py</span></div><div class="line">├── mlp<span class="selector-class">.py</span></div><div class="line">├── mobilenet<span class="selector-class">.py</span></div><div class="line">├── resnet-v1<span class="selector-class">.py</span></div><div class="line">├── resnet<span class="selector-class">.py</span></div><div class="line">├── resnext<span class="selector-class">.py</span></div><div class="line">└── vgg.py</div></pre></td></tr></table></figure>
<p>更重要的是，我们看看alexnet的代码：</p>
<figure class="highlight nix"><table><tr><td class="code"><pre><div class="line"><span class="built_in">import</span> mxnet as mx</div><div class="line"><span class="built_in">import</span> numpy as np</div><div class="line"></div><div class="line">def get_symbol(num_classes, <span class="attr">dtype='float32',</span> **kwargs):</div><div class="line">    <span class="attr">input_data</span> = mx.sym.Variable(<span class="attr">name="data")</span></div><div class="line">    <span class="keyword">if</span> <span class="attr">dtype</span> == 'float16':</div><div class="line">        <span class="attr">input_data</span> = mx.sym.Cast(<span class="attr">data=input_data,</span> <span class="attr">dtype=np.float16)</span></div><div class="line">    <span class="comment"># stage 1</span></div><div class="line">    <span class="attr">conv1</span> = mx.sym.Convolution(<span class="attr">name='conv1',</span></div><div class="line">        <span class="attr">data=input_data,</span> <span class="attr">kernel=(11,</span> <span class="number">11</span>), <span class="attr">stride=(4,</span> <span class="number">4</span>), <span class="attr">num_filter=96)</span></div><div class="line">    <span class="attr">relu1</span> = mx.sym.Activation(<span class="attr">data=conv1,</span> <span class="attr">act_type="relu")</span></div><div class="line">    <span class="attr">lrn1</span> = mx.sym.LRN(<span class="attr">data=relu1,</span> <span class="attr">alpha=0.0001,</span> <span class="attr">beta=0.75,</span> <span class="attr">knorm=2,</span> <span class="attr">nsize=5)</span></div><div class="line">    <span class="attr">pool1</span> = mx.sym.Pooling(</div><div class="line">        <span class="attr">data=lrn1,</span> <span class="attr">pool_type="max",</span> <span class="attr">kernel=(3,</span> <span class="number">3</span>), <span class="attr">stride=(2,2))</span></div><div class="line">    <span class="comment"># stage 2</span></div><div class="line">    <span 
class="attr">conv2</span> = mx.sym.Convolution(<span class="attr">name='conv2',</span></div><div class="line">        <span class="attr">data=pool1,</span> <span class="attr">kernel=(5,</span> <span class="number">5</span>), <span class="attr">pad=(2,</span> <span class="number">2</span>), <span class="attr">num_filter=256)</span></div><div class="line">    <span class="attr">relu2</span> = mx.sym.Activation(<span class="attr">data=conv2,</span> <span class="attr">act_type="relu")</span></div><div class="line">    <span class="attr">lrn2</span> = mx.sym.LRN(<span class="attr">data=relu2,</span> <span class="attr">alpha=0.0001,</span> <span class="attr">beta=0.75,</span> <span class="attr">knorm=2,</span> <span class="attr">nsize=5)</span></div><div class="line">    <span class="attr">pool2</span> = mx.sym.Pooling(<span class="attr">data=lrn2,</span> <span class="attr">kernel=(3,</span> <span class="number">3</span>), <span class="attr">stride=(2,</span> <span class="number">2</span>), <span class="attr">pool_type="max")</span></div><div class="line">    <span class="comment"># stage 3</span></div><div class="line">    <span class="attr">conv3</span> = mx.sym.Convolution(<span class="attr">name='conv3',</span></div><div class="line">        <span class="attr">data=pool2,</span> <span class="attr">kernel=(3,</span> <span class="number">3</span>), <span class="attr">pad=(1,</span> <span class="number">1</span>), <span class="attr">num_filter=384)</span></div><div class="line">    <span class="attr">relu3</span> = mx.sym.Activation(<span class="attr">data=conv3,</span> <span class="attr">act_type="relu")</span></div><div class="line">    <span class="attr">conv4</span> = mx.sym.Convolution(<span class="attr">name='conv4',</span></div><div class="line">        <span class="attr">data=relu3,</span> <span class="attr">kernel=(3,</span> <span class="number">3</span>), <span class="attr">pad=(1,</span> <span class="number">1</span>), <span 
class="attr">num_filter=384)</span></div><div class="line">    <span class="attr">relu4</span> = mx.sym.Activation(<span class="attr">data=conv4,</span> <span class="attr">act_type="relu")</span></div><div class="line">    <span class="attr">conv5</span> = mx.sym.Convolution(<span class="attr">name='conv5',</span></div><div class="line">        <span class="attr">data=relu4,</span> <span class="attr">kernel=(3,</span> <span class="number">3</span>), <span class="attr">pad=(1,</span> <span class="number">1</span>), <span class="attr">num_filter=256)</span></div><div class="line">    <span class="attr">relu5</span> = mx.sym.Activation(<span class="attr">data=conv5,</span> <span class="attr">act_type="relu")</span></div><div class="line">    <span class="attr">pool3</span> = mx.sym.Pooling(<span class="attr">data=relu5,</span> <span class="attr">kernel=(3,</span> <span class="number">3</span>), <span class="attr">stride=(2,</span> <span class="number">2</span>), <span class="attr">pool_type="max")</span></div><div class="line">    <span class="comment"># stage 4</span></div><div class="line">    <span class="attr">flatten</span> = mx.sym.Flatten(<span class="attr">data=pool3)</span></div><div class="line">    <span class="attr">fc1</span> = mx.sym.FullyConnected(<span class="attr">name='fc1',</span> <span class="attr">data=flatten,</span> <span class="attr">num_hidden=4096)</span></div><div class="line">    <span class="attr">relu6</span> = mx.sym.Activation(<span class="attr">data=fc1,</span> <span class="attr">act_type="relu")</span></div><div class="line">    <span class="attr">dropout1</span> = mx.sym.Dropout(<span class="attr">data=relu6,</span> <span class="attr">p=0.5)</span></div><div class="line">    <span class="comment"># stage 5</span></div><div class="line">    <span class="attr">fc2</span> = mx.sym.FullyConnected(<span class="attr">name='fc2',</span> <span class="attr">data=dropout1,</span> <span class="attr">num_hidden=4096)</span></div><div 
class="line">    <span class="attr">relu7</span> = mx.sym.Activation(<span class="attr">data=fc2,</span> <span class="attr">act_type="relu")</span></div><div class="line">    <span class="attr">dropout2</span> = mx.sym.Dropout(<span class="attr">data=relu7,</span> <span class="attr">p=0.5)</span></div><div class="line">    <span class="comment"># stage 6</span></div><div class="line">    <span class="attr">fc3</span> = mx.sym.FullyConnected(<span class="attr">name='fc3',</span> <span class="attr">data=dropout2,</span> <span class="attr">num_hidden=num_classes)</span></div><div class="line">    <span class="keyword">if</span> <span class="attr">dtype</span> == 'float16':</div><div class="line">        <span class="attr">fc3</span> = mx.sym.Cast(<span class="attr">data=fc3,</span> <span class="attr">dtype=np.float32)</span></div><div class="line">    <span class="attr">softmax</span> = mx.sym.SoftmaxOutput(<span class="attr">data=fc3,</span> <span class="attr">name='softmax')</span></div><div class="line">    return softmax</div></pre></td></tr></table></figure>
<p>非常非常非常简洁！！！！整个网络只是一个函数，不同任务之间唯一需要改的就是类别数目 <code>num_classes</code>，函数最后返回一个以 <code>SoftmaxOutput</code>（即softmax loss）结尾的symbol。</p>
<p>最后我们看看怎么把数据导入，然后训练的！！！</p>
<figure class="highlight python"><table><tr><td class="code"><pre><div class="line"><span class="string">"""</span></div><div class="line">train pipe line in mxnet</div><div class="line">"""</div><div class="line"><span class="keyword">import</span> mxnet <span class="keyword">as</span> mx</div><div class="line"><span class="keyword">from</span> symbols.vgg <span class="keyword">import</span> get_vgg</div><div class="line"></div><div class="line"></div><div class="line"><span class="function"><span class="keyword">def</span> <span class="title">train</span><span class="params">()</span>:</span></div><div class="line">    num_classes = <span class="number">120</span></div><div class="line">    batch_size = <span class="number">64</span></div><div class="line">    <span class="comment"># shape not have to be it exactly are</span></div><div class="line">    data_shape = (<span class="number">3</span>, <span class="number">64</span>, <span class="number">64</span>)</div><div class="line">    num_epoch = <span class="number">50</span></div><div class="line">    prefix = <span class="string">'standford_dogs_model'</span></div><div class="line"></div><div class="line">    train_iter = mx.io.ImageRecordIter(</div><div class="line">        path_imgrec=<span class="string">"data/standford_dogs_train.rec"</span>,</div><div class="line">        data_shape=data_shape,</div><div class="line">        batch_size=batch_size,</div><div class="line">    )</div><div class="line"></div><div class="line">    val_iter = mx.io.ImageRecordIter(</div><div class="line">        path_imgrec=<span class="string">"data/standford_dogs_test.rec"</span>,</div><div class="line">        data_shape=data_shape,</div><div class="line">        batch_size=batch_size,</div><div class="line">    )</div><div class="line"></div><div class="line">    model = mx.model.FeedForward(</div><div class="line">        <span class="comment"># set mx.gpu(0, 1) for multiple gpu</span></div><div class="line">        
ctx=mx.cpu(),</div><div class="line">        symbol=get_vgg(num_classes=num_classes),</div><div class="line">        num_epoch=num_epoch,</div><div class="line">        learning_rate=<span class="number">0.01</span>,</div><div class="line">    )</div><div class="line"></div><div class="line">    model.fit(</div><div class="line">        X=train_iter,</div><div class="line">        eval_data=val_iter,</div><div class="line">        <span class="comment"># every 10 iteration log info</span></div><div class="line">        batch_end_callback=mx.callback.Speedometer(batch_size, <span class="number">10</span>),</div><div class="line">        epoch_end_callback=mx.callback.do_checkpoint(prefix=prefix)</div><div class="line">    )</div><div class="line"></div><div class="line"></div><div class="line"><span class="keyword">if</span> __name__ == <span class="string">'__main__'</span>:</div><div class="line">    train()</div></pre></td></tr></table></figure>
<p>尼玛，简直简单到想哭。大家注意这里get_vgg就是直接从官方的example/image-classification里面拿的，我们训练一个vgg看看。运行之后发现网络已经跑起来了：</p>
<p><img src="https://ooo.0o0.ooo/2017/10/12/59df382c2c753.jpeg" alt=""></p>
<p>温馨提示一下，MXNet已经不推荐上面这种 <code>mx.model.FeedForward</code>（Model）的写法了，官方推荐改用 <code>mx.mod.Module</code>。上面的写法有点像PyTorch那样的“生成式”写法；Model和Module的区别就是，后者更加Tensor化，也就是图化，运行之前先把GPU占领一下再说。</p>
<p>OK， MXNet的坑已经踩完了。我来总结一下MXNet不为人知的几点：</p>
<ul>
<li><strong>这是一个良心框架</strong>。可以看出它的开发者在用心地追求速度和易用性，否则也不会推出Gluon这个接口了，这个接口就是为了让普通开发者用起来更容易，同时兼顾速度；</li>
<li>MXNet是唯一一个比较中立的框架。你要知道，Google推出TensorFlow可是有小九九的，我猜其内部至少有几套速度更快的纯C写的版本，否则TensorFlow怎么那么慢？不拉开差距怎么来的KPI？怎么让全球开发者为Google服务？(不是Google员工也不是Google敌对员工，逃…)</li>
<li>MXNet的未来潜力很大，我最近在研究MXNet构建复杂的网络，比如Cycle-GAN，比如Seq2Seq的实现，但是不得不承认，这方面TensorFlow更加强大…</li>
</ul>
<h2 id="2-PaddlePaddle"><a href="#2-PaddlePaddle" class="headerlink" title="2. PaddlePaddle"></a>2. PaddlePaddle</h2><p>为什么第二个评测用PaddlePaddle？因为它最近表现很好，但是知道的人很少，本着为开发者引路的原则，给它增加一下曝光度。说实话，很多人不知道PaddlePaddle已经升级到了v2的Python API，而且内部还引入了很多Go语言的代码。我没有仔细看这些代码是用来干啥的，但是很显然，PaddlePaddle在追求速度。</p>
<p>对Paddle的评测我这里列举以下Paddle的几个亮点的地方：</p>
<ul>
<li>相对来说更易用的API，所谓相对是因为，它还是有一些冗杂的地方；</li>
<li>占用内存小，速度快，Paddle在百度内部应该也服务了相当多的项目，因此工业应用不成问题;</li>
<li>中文支持：不像国外的框架，PaddlePaddle有相当多的中文文档；</li>
<li>PaddlePaddle在自然语言处理等方向上有很多现成的例程，比如情感分类，甚至语音识别都有Demo；</li>
<li>PaddlePaddle支持多机多卡训练，也算是集大成者。</li>
</ul>
<p>关于PaddlePaddle的使用流程（Pipeline），请移步我之前写的一篇文章：<a href="https://jinfagang.github.io/2017/10/10/paddlepaddle系列之三行代码从入门到精通/" target="_blank" rel="noopener">PaddlePaddle三行代码从入门到精通</a>。</p>
<h2 id="3-TensorFlow"><a href="#3-TensorFlow" class="headerlink" title="3. TensorFlow"></a>3. TensorFlow</h2><p>关于tf，真的是爱恨交加，从刚入手到现在，它的API的繁杂性以及训练的繁琐几乎让人望而却步，不过好在它有一个非常强大的生态。我们来看看TensorFlow做分类任务应该怎么做。</p>
<p>首先，毫无疑问，最好的方法是把图片放到tfrecord这种文件格式中去。但是如何生成tfrecord是个蛋疼的问题。在这里我声明一点，tfrecord和MXNet的rec文件不同：</p>
<p><em>tfrecord是将数据以键值对（feature）的形式存放起来，每条记录就是一个Example；而MXNet需要先建立一个lst列表，然后再根据lst转成二进制rec文件。好吧其实也差不多，不过你应该能理解我说的意思。</em></p>
<p>我们看一下一个用来将图片转为tfrecord的代码：</p>
<figure class="highlight python"><table><tr><td class="code"><pre><div class="line"><span class="keyword">from</span> __future__ <span class="keyword">import</span> absolute_import</div><div class="line"><span class="keyword">from</span> __future__ <span class="keyword">import</span> division</div><div class="line"><span class="keyword">from</span> __future__ <span class="keyword">import</span> print_function</div><div class="line"><span class="keyword">from</span> datetime <span class="keyword">import</span> datetime</div><div class="line"><span class="keyword">import</span> os</div><div class="line"><span class="keyword">import</span> random</div><div class="line"><span class="keyword">import</span> sys</div><div class="line"><span class="keyword">import</span> threading</div><div class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</div><div class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</div><div class="line"></div><div class="line"></div><div class="line"><span class="class"><span class="keyword">class</span> <span class="title">TFRecordsGenerator</span><span class="params">(object)</span>:</span></div><div class="line">    <span class="string">"""</span></div><div class="line">    this class is using for tf_records generations in image classification use</div><div class="line">    For usages:</div><div class="line">    All images must contains in different folders, TFRecordsGenerator will traverse</div><div class="line">    all folders and find different classes.</div><div class="line">    """</div><div class="line"></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self,</span></span></div><div class="line">                 name,</div><div class="line">                 images_dir,</div><div class="line">                 classes_file_path,</div><div class="line">                 
tf_records_save_dir,</div><div class="line">                 num_shards=<span class="number">4</span>,</div><div class="line">                 num_threads=<span class="number">4</span>):</div><div class="line">        self.name = name</div><div class="line">        self.classes_file_path = classes_file_path</div><div class="line">        self.images_dir = images_dir</div><div class="line">        self.tf_records_saved_dir = tf_records_save_dir</div><div class="line">        self.num_shards = num_shards</div><div class="line">        self.num_threads = num_threads</div><div class="line"></div><div class="line"><span class="meta">    @staticmethod</span></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">_int64_feature</span><span class="params">(value)</span>:</span></div><div class="line">        <span class="keyword">if</span> <span class="keyword">not</span> isinstance(value, list):</div><div class="line">            value = [value]</div><div class="line">        <span class="keyword">return</span> tf.train.Feature(int64_list=tf.train.Int64List(value=value))</div><div class="line"></div><div class="line"><span class="meta">    @staticmethod</span></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">_bytes_feature</span><span class="params">(value)</span>:</span></div><div class="line">        <span class="keyword">return</span> tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))</div><div class="line"></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">_convert_to_example</span><span class="params">(self, filename, image_buffer, label, text, height, width)</span>:</span></div><div class="line">        <span class="string">"""</span></div><div class="line">        Example for image classification</div><div class="line">        :param filename:</div><div class="line">        :param 
image_buffer:</div><div class="line">        :param label:</div><div class="line">        :param text:</div><div class="line">        :param height:</div><div class="line">        :param width:</div><div class="line">        :return:</div><div class="line">        """</div><div class="line">        color_space = <span class="string">'RGB'</span></div><div class="line">        channels = <span class="number">3</span></div><div class="line">        image_format = <span class="string">'JPEG'</span></div><div class="line">        example = tf.train.Example(features=tf.train.Features(feature=&#123;</div><div class="line">            <span class="string">'image/height'</span>: self._int64_feature(height),</div><div class="line">            <span class="string">'image/width'</span>: self._int64_feature(width),</div><div class="line">            <span class="string">'image/color_space'</span>: self._bytes_feature(tf.compat.as_bytes(color_space)),</div><div class="line">            <span class="string">'image/channels'</span>: self._int64_feature(channels),</div><div class="line">            <span class="string">'image/class/label'</span>: self._int64_feature(label),</div><div class="line">            <span class="string">'image/class/text'</span>: self._bytes_feature(tf.compat.as_bytes(text)),</div><div class="line">            <span class="string">'image/format'</span>: self._bytes_feature(tf.compat.as_bytes(image_format)),</div><div class="line">            <span class="string">'image/filename'</span>: self._bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),</div><div class="line">            <span class="string">'image/encoded'</span>: self._bytes_feature(tf.compat.as_bytes(image_buffer))&#125;))</div><div class="line">        <span class="keyword">return</span> example</div><div class="line"></div><div class="line">    <span class="class"><span class="keyword">class</span> <span class="title">ImageCoder</span><span 
class="params">(object)</span>:</span></div><div class="line">        <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></div><div class="line">            self._sess = tf.Session()</div><div class="line">            self._png_data = tf.placeholder(dtype=tf.string)</div><div class="line">            image = tf.image.decode_png(self._png_data, channels=<span class="number">3</span>)</div><div class="line">            self._png_to_jpeg = tf.image.encode_jpeg(image, format=<span class="string">'rgb'</span>, quality=<span class="number">100</span>)</div><div class="line">            self._decode_jpeg_data = tf.placeholder(dtype=tf.string)</div><div class="line">            self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=<span class="number">3</span>)</div><div class="line"></div><div class="line">        <span class="function"><span class="keyword">def</span> <span class="title">png_to_jpeg</span><span class="params">(self, image_data)</span>:</span></div><div class="line">            <span class="keyword">return</span> self._sess.run(self._png_to_jpeg,</div><div class="line">                                  feed_dict=&#123;self._png_data: image_data&#125;)</div><div class="line"></div><div class="line">        <span class="function"><span class="keyword">def</span> <span class="title">decode_jpeg</span><span class="params">(self, image_data)</span>:</span></div><div class="line">            image = self._sess.run(self._decode_jpeg,</div><div class="line">                                   feed_dict=&#123;self._decode_jpeg_data: image_data&#125;)</div><div class="line">            <span class="keyword">assert</span> len(image.shape) == <span class="number">3</span></div><div class="line">            <span class="keyword">assert</span> image.shape[<span class="number">2</span>] == <span class="number">3</span></div><div class="line">            <span 
class="keyword">return</span> image</div><div class="line"></div><div class="line"><span class="meta">    @staticmethod</span></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">_is_png</span><span class="params">(filename)</span>:</span></div><div class="line">        <span class="keyword">return</span> <span class="string">'.png'</span> <span class="keyword">in</span> filename</div><div class="line"></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">_process_image</span><span class="params">(self, filename, coder)</span>:</span></div><div class="line">        <span class="keyword">with</span> tf.gfile.FastGFile(filename, <span class="string">'rb'</span>) <span class="keyword">as</span> f:</div><div class="line">            image_data = f.read()</div><div class="line">        <span class="keyword">if</span> self._is_png(filename):</div><div class="line">            print(<span class="string">'Converting PNG to JPEG for %s'</span> % filename)</div><div class="line">            image_data = coder.png_to_jpeg(image_data)</div><div class="line">        image = coder.decode_jpeg(image_data)</div><div class="line">        <span class="keyword">assert</span> len(image.shape) == <span class="number">3</span></div><div class="line">        height = image.shape[<span class="number">0</span>]</div><div class="line">        width = image.shape[<span class="number">1</span>]</div><div class="line">        <span class="keyword">assert</span> image.shape[<span class="number">2</span>] == <span class="number">3</span></div><div class="line">        <span class="keyword">return</span> image_data, height, width</div><div class="line"></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">_process_image_files_batch</span><span class="params">(self, coder, thread_index, ranges, name, file_names,</span></span></div><div class="line">  
                                 texts, labels, num_shards):</div><div class="line">        num_threads = len(ranges)</div><div class="line">        <span class="keyword">assert</span> <span class="keyword">not</span> num_shards % num_threads</div><div class="line">        num_shards_per_batch = int(num_shards / num_threads)</div><div class="line"></div><div class="line">        shard_ranges = np.linspace(ranges[thread_index][<span class="number">0</span>],</div><div class="line">                                   ranges[thread_index][<span class="number">1</span>],</div><div class="line">                                   num_shards_per_batch + <span class="number">1</span>).astype(int)</div><div class="line">        num_files_in_thread = ranges[thread_index][<span class="number">1</span>] - ranges[thread_index][<span class="number">0</span>]</div><div class="line"></div><div class="line">        counter = <span class="number">0</span></div><div class="line">        <span class="keyword">for</span> s <span class="keyword">in</span> range(num_shards_per_batch):</div><div class="line">            shard = thread_index * num_shards_per_batch + s</div><div class="line">            output_filename = <span class="string">'%s-%.5d-of-%.5d.tfrecord'</span> % (name, shard, num_shards)</div><div class="line">            output_file = os.path.join(self.tf_records_saved_dir, output_filename)</div><div class="line">            writer = tf.python_io.TFRecordWriter(output_file)</div><div class="line"></div><div class="line">            shard_counter = <span class="number">0</span></div><div class="line">            files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + <span class="number">1</span>], dtype=int)</div><div class="line">            <span class="keyword">for</span> i <span class="keyword">in</span> files_in_shard:</div><div class="line">                filename = file_names[i]</div><div class="line">                label = labels[i]</div><div class="line">      
          text = texts[i]</div><div class="line">                image_buffer, height, width = self._process_image(filename, coder)</div><div class="line">                example = self._convert_to_example(filename, image_buffer, label,</div><div class="line">                                                   text, height, width)</div><div class="line">                writer.write(example.SerializeToString())</div><div class="line">                shard_counter += <span class="number">1</span></div><div class="line">                counter += <span class="number">1</span></div><div class="line">                <span class="keyword">if</span> <span class="keyword">not</span> counter % <span class="number">1000</span>:</div><div class="line">                    print(<span class="string">'%s [thread %d]: Processed %d of %d images in thread batch.'</span> %</div><div class="line">                          (datetime.now(), thread_index, counter, num_files_in_thread))</div><div class="line">                    sys.stdout.flush()</div><div class="line">            writer.close()</div><div class="line">            print(<span class="string">'%s [thread %d]: Wrote %d images to %s'</span> %</div><div class="line">                  (datetime.now(), thread_index, shard_counter, output_file))</div><div class="line">            sys.stdout.flush()</div><div class="line">            shard_counter = <span class="number">0</span></div><div class="line">        print(<span class="string">'%s [thread %d]: Wrote %d images to %d shards.'</span> %</div><div class="line">              (datetime.now(), thread_index, counter, num_files_in_thread))</div><div class="line">        sys.stdout.flush()</div><div class="line"></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">_process_image_files</span><span class="params">(self, file_names, texts, labels)</span>:</span></div><div class="line">        <span class="keyword">assert</span> 
len(file_names) == len(texts)</div><div class="line">        <span class="keyword">assert</span> len(file_names) == len(labels)</div><div class="line">        spacing = np.linspace(<span class="number">0</span>, len(file_names), self.num_threads + <span class="number">1</span>).astype(np.int)</div><div class="line">        ranges = []</div><div class="line">        <span class="keyword">for</span> i <span class="keyword">in</span> range(len(spacing) - <span class="number">1</span>):</div><div class="line">            ranges.append([spacing[i], spacing[i + <span class="number">1</span>]])</div><div class="line">        print(<span class="string">'Launching %d threads for spacings: %s'</span> % (self.num_threads, ranges))</div><div class="line">        sys.stdout.flush()</div><div class="line">        coord = tf.train.Coordinator()</div><div class="line">        coder = self.ImageCoder()</div><div class="line">        threads = []</div><div class="line">        <span class="keyword">for</span> thread_index <span class="keyword">in</span> range(len(ranges)):</div><div class="line">            args = (coder, thread_index, ranges, self.name, file_names,</div><div class="line">                    texts, labels, self.num_shards)</div><div class="line">            t = threading.Thread(target=self._process_image_files_batch, args=args)</div><div class="line">            t.start()</div><div class="line">            threads.append(t)</div><div class="line">        coord.join(threads)</div><div class="line">        print(<span class="string">'%s: Finished writing all %d images in data set.'</span> %</div><div class="line">              (datetime.now(), len(file_names)))</div><div class="line">        sys.stdout.flush()</div><div class="line"></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">_find_image_files</span><span class="params">(self)</span>:</span></div><div class="line">        print(<span 
class="string">'Determining list of input files and labels from %s.'</span> % self.images_dir)</div><div class="line">        unique_labels = [l.strip() <span class="keyword">for</span> l <span class="keyword">in</span> tf.gfile.FastGFile(</div><div class="line">            self.classes_file_path, <span class="string">'r'</span>).readlines()]</div><div class="line"></div><div class="line">        labels = []</div><div class="line">        file_names = []</div><div class="line">        texts = []</div><div class="line">        label_index = <span class="number">1</span></div><div class="line"></div><div class="line">        <span class="keyword">for</span> text <span class="keyword">in</span> unique_labels:</div><div class="line">            jpeg_file_path = <span class="string">'%s/%s/*'</span> % (self.images_dir, text)</div><div class="line">            matching_files = tf.gfile.Glob(jpeg_file_path)</div><div class="line"></div><div class="line">            labels.extend([label_index] * len(matching_files))</div><div class="line">            texts.extend([text] * len(matching_files))</div><div class="line">            file_names.extend(matching_files)</div><div class="line"></div><div class="line">            <span class="keyword">if</span> <span class="keyword">not</span> label_index % <span class="number">100</span>:</div><div class="line">                print(<span class="string">'Finished finding files in %d of %d classes.'</span> % (</div><div class="line">                    label_index, len(labels)))</div><div class="line">            label_index += <span class="number">1</span></div><div class="line"></div><div class="line">        shuffled_index = list(range(len(file_names)))</div><div class="line">        random.seed(<span class="number">12345</span>)</div><div class="line">        random.shuffle(shuffled_index)</div><div class="line"></div><div class="line">        file_names = [file_names[i] <span class="keyword">for</span> i <span 
class="keyword">in</span> shuffled_index]</div><div class="line">        texts = [texts[i] <span class="keyword">for</span> i <span class="keyword">in</span> shuffled_index]</div><div class="line">        labels = [labels[i] <span class="keyword">for</span> i <span class="keyword">in</span> shuffled_index]</div><div class="line"></div><div class="line">        print(<span class="string">'Found %d JPEG files across %d labels inside %s.'</span> %</div><div class="line">              (len(file_names), len(unique_labels), self.images_dir))</div><div class="line">        print(<span class="string">'[INFO] Attempting logging out file_names list: &#123;&#125;'</span>.format(<span class="string">'\n'</span>.join(file_names)))</div><div class="line">        <span class="keyword">return</span> file_names, texts, labels</div><div class="line"></div><div class="line">    <span class="function"><span class="keyword">def</span> <span class="title">generate</span><span class="params">(self)</span>:</span></div><div class="line">        <span class="keyword">assert</span> <span class="keyword">not</span> self.num_shards % self.num_threads, (</div><div class="line">            <span class="string">'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards'</span>)</div><div class="line">        print(<span class="string">'Saving results to %s'</span> % self.tf_records_saved_dir)</div><div class="line"></div><div class="line">        file_names, texts, labels = self._find_image_files()</div><div class="line">        self._process_image_files(file_names, texts, labels)</div><div class="line">        print(<span class="string">'All Done! Solved &#123;&#125; images. tf_records file saved into &#123;&#125;.'</span>.format(len(file_names), os.path.abspath(</div><div class="line">            self.tf_records_saved_dir)))</div></pre></td></tr></table></figure>
<p>这是我封装的一个类，只要传入路径并调用generate()就可以生成TFRecord文件。看到这里估计你已经哭了，尼玛这么复杂?!!!!????</p>
<p>好吧，暂且不管这个具体是怎么实现的，再来看看数据是怎么load进模型的吧：</p>
<figure class="highlight python"><table><tr><td class="code"><pre><div class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</div><div class="line"><span class="keyword">import</span> logging</div><div class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</div><div class="line"><span class="keyword">import</span> os</div><div class="line"><span class="keyword">import</span> time</div><div class="line"><span class="keyword">from</span> datasets.tiny5.tiny5 <span class="keyword">import</span> Tiny5</div><div class="line"><span class="keyword">from</span> models.alexnet <span class="keyword">import</span> AlexNet</div><div class="line"><span class="keyword">from</span> models.vgg <span class="keyword">import</span> VGGNet</div><div class="line"><span class="keyword">from</span> models.fanet <span class="keyword">import</span> FaNet</div><div class="line"></div><div class="line">logging.basicConfig(level=logging.DEBUG,</div><div class="line">                    format=<span class="string">'%(asctime)s %(filename)s line:%(lineno)d %(levelname)s %(message)s'</span>,</div><div class="line">                    datefmt=<span class="string">'%a, %d %b %Y %H:%M:%S'</span>)</div><div class="line"></div><div class="line">tf.app.flags.DEFINE_string(<span class="string">'checkpoints_dir'</span>, <span class="string">'./checkpoints/tiny5/'</span>, <span class="string">'checkpoints save path.'</span>)</div><div class="line">tf.app.flags.DEFINE_string(<span class="string">'model_prefix'</span>, <span class="string">'tiny5-alex-net'</span>, <span class="string">'model save prefix.'</span>)</div><div class="line">tf.app.flags.DEFINE_boolean(<span class="string">'is_restore'</span>, <span class="keyword">False</span>, <span class="string">'to restore from previous or not.'</span>)</div><div class="line"></div><div class="line">tf.app.flags.DEFINE_integer(<span class="string">'target_width'</span>, <span 
class="number">256</span>, <span class="string">'target width for resize.'</span>)</div><div class="line">tf.app.flags.DEFINE_integer(<span class="string">'target_height'</span>, <span class="number">256</span>, <span class="string">'target height for resize.'</span>)</div><div class="line">tf.app.flags.DEFINE_integer(<span class="string">'batch_size'</span>, <span class="number">24</span>, <span class="string">'batch size for train.'</span>)</div><div class="line"></div><div class="line"></div><div class="line">FLAGS = tf.app.flags.FLAGS</div><div class="line"></div><div class="line"></div><div class="line"><span class="function"><span class="keyword">def</span> <span class="title">running</span><span class="params">(is_train=True)</span>:</span></div><div class="line">    <span class="keyword">if</span> <span class="keyword">not</span> os.path.exists(FLAGS.checkpoints_dir):</div><div class="line">        os.makedirs(FLAGS.checkpoints_dir)</div><div class="line"></div><div class="line">    tiny5 = Tiny5(</div><div class="line">        images_dir=<span class="string">'./datasets/tiny5/images'</span>,</div><div class="line">        classes_file_path=<span class="string">'./datasets/tiny5/tiny5_classes.txt'</span>,</div><div class="line">        target_height=FLAGS.target_height,</div><div class="line">        target_width=FLAGS.target_width,</div><div class="line">        batch_size=FLAGS.batch_size</div><div class="line">    )</div><div class="line">    images, labels = tiny5.batch_inputs()</div><div class="line">    print(images)</div><div class="line"></div><div class="line">    <span class="comment"># model = AlexNet(num_classes=5)</span></div><div class="line">    <span class="comment"># model = VGGNet(num_classes=5)</span></div><div class="line">    model = FaNet(num_classes=<span class="number">5</span>)</div><div class="line">    config = tf.ConfigProto()</div><div class="line">    config.gpu_options.allow_growth = <span class="keyword">True</span></div><div 
class="line">    saver = tf.train.Saver(max_to_keep=<span class="number">2</span>)</div><div class="line">    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())</div><div class="line">    <span class="keyword">with</span> tf.Session(config=config) <span class="keyword">as</span> sess:</div><div class="line">        sess.run(init_op)</div><div class="line">        coord = tf.train.Coordinator()</div><div class="line">        threads = tf.train.start_queue_runners(sess=sess, coord=coord)</div><div class="line"></div><div class="line">        start_epoch = <span class="number">0</span></div><div class="line">        checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)</div><div class="line">        <span class="keyword">if</span> FLAGS.is_restore:</div><div class="line">            <span class="keyword">if</span> checkpoint:</div><div class="line">                saver.restore(sess, checkpoint)</div><div class="line">                logging.info(<span class="string">"restore from the checkpoint &#123;0&#125;"</span>.format(checkpoint))</div><div class="line">                start_epoch += int(checkpoint.split(<span class="string">'-'</span>)[<span class="number">-1</span>])</div><div class="line">        <span class="keyword">if</span> is_train:</div><div class="line">            step = <span class="number">0</span></div><div class="line">            logging.info(<span class="string">'training start...'</span>)</div><div class="line">            <span class="keyword">try</span>:</div><div class="line">                <span class="keyword">while</span> <span class="keyword">not</span> coord.should_stop():</div><div class="line">                    feed_dict = model.make_train_inputs(images, labels)</div><div class="line">                    _, loss, step = sess.run(</div><div class="line">                        [model.train_op, model.loss, model.global_step], feed_dict=feed_dict</div><div class="line">                    )</div><div 
class="line">                    logging.info(<span class="string">'epoch &#123;&#125;,  loss &#123;&#125;'</span>.format(step, loss))</div><div class="line"></div><div class="line">            <span class="keyword">except</span> tf.errors.OutOfRangeError:</div><div class="line">                logging.info(<span class="string">'optimization done! enjoy color net.'</span>)</div><div class="line">                saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=step)</div><div class="line">            <span class="keyword">except</span> KeyboardInterrupt:</div><div class="line">                logging.info(<span class="string">'interrupt manually, try saving checkpoint for now...'</span>)</div><div class="line">                saver.save(sess, os.path.join(FLAGS.checkpoints_dir, FLAGS.model_prefix), global_step=step)</div><div class="line">                logging.info(<span class="string">'last epoch were saved, next time will start from epoch &#123;&#125;.'</span>.format(step))</div><div class="line">            <span class="keyword">finally</span>:</div><div class="line">                coord.request_stop()</div><div class="line">                coord.join(threads)</div><div class="line">        <span class="keyword">else</span>:</div><div class="line">            logging.info(<span class="string">'start inference...'</span>)</div><div class="line"></div><div class="line">            inference_image_path = <span class="string">'./images/1.png'</span></div><div class="line">            input_image = tiny5.single_image_input(inference_image_path)</div><div class="line">            feed_dict = model.make_inference_inputs(input_image)</div><div class="line">            outputs = sess.run([model.inference_outputs(n_top=<span class="number">2</span>)], feed_dict=feed_dict)</div><div class="line">            print(outputs)</div><div class="line"></div><div class="line"></div><div class="line"><span class="function"><span 
class="keyword">def</span> <span class="title">main</span><span class="params">(args)</span>:</span></div><div class="line">    running(args)</div><div class="line"></div><div class="line"></div><div class="line"><span class="keyword">if</span> __name__ == <span class="string">'__main__'</span>:</div><div class="line">    tf.app.run()</div></pre></td></tr></table></figure>
<p>这个训练的代码，大概的训练步骤分为：</p>
<ul>
<li>使用tf.ConfigProto()生成一个config，设置GPU显存按需增长（<code>gpu_options.allow_growth = True</code>）；同时创建一个saver，它的<code>max_to_keep</code>参数表示最多保留的checkpoint数目；</li>
<li>创建初始化全局变量和局部变量的op，并用tf.train.Coordinator()作为训练协调者来管理输入队列的线程，然后初始化图中的变量；</li>
<li>用while循环不断训练（直到coord.should_stop()为真），并在整个循环外面catch一下tf.errors.OutOfRangeError——输入队列里的数据被取完时会抛出这个异常，表示训练结束，此时保存模型；再catch一下KeyboardInterrupt，这样手动中断训练时也能保存checkpoint；</li>
<li>最后是保存模型</li>
</ul>
<p>大家可以感受一下TensorFlow一整套流程下来的复杂程度。这里面还没有写我的网络，也没有写我的数据DataLoader，完整的原始代码可以在我的GitHub仓库找到：<a href="https://github.com/jinfagang/tensorflow_classifier.git" target="_blank" rel="noopener">传送门</a>。如果你觉得那个项目过于陈旧，可以跟进我的一些最新项目，我近期在TensorFlow上做的工作有：</p>
<ul>
<li>用Google最新nmt模型训练聊天机器人；</li>
<li>使用GAN做CycleGAN生成；</li>
<li>使用知识库（Knowledge Base）和知识图谱做问答系统；</li>
<li>目标检测和分割等常规性工作</li>
</ul>
<h2 id="4-PyTorch"><a href="#4-PyTorch" class="headerlink" title="4. PyTorch"></a>4. PyTorch</h2><p>PyTorch如果做图片预测我就不详细讲了，很多人说PyTorch很简单，但是我并没有觉得简单到哪里去，我总结一下PyTorch目前来说一些优点吧。</p>
<ul>
<li>立即式（动态图）编程，也就是运行立马出结果；不同于TensorFlow的静态图模式——在TensorFlow里你必须先把整个计算图写完，再放到Session里运行之后才知道结果是什么；</li>
<li>安装也比较方便，但是跨平台部署就比较麻烦了，这也和PyTorch的定位有关，当然PyTorch刚推出来的时候有几篇官方教程写的不错，主要是RNN文本生成，Seq2Seq翻译的实现，有兴趣的同学可以看一下，但是都是非常简单的实现，跟TensorFlow的官方例子差距蛮大；</li>
<li>只是构建网络比较简单，但是具体训练的PipeLine还是有点麻烦，尤其是我每次变量还得指定是CPU还是GPU，每次load模型的时候还得load是CPU还是GPU，个人感觉略麻烦；</li>
</ul>
<p>PyTorch推出来的时候很火，现在貌似熄火了….</p>
<h2 id="5-Caffe2"><a href="#5-Caffe2" class="headerlink" title="5. Caffe2"></a>5. Caffe2</h2><p>caffe2 不得不提一下，caffe的进化版本？？？？caffe用着还好，c++调接口还蛮方便，例子也很多，caffe2为毛主打python，还python2？？？不过这也跟caffe2定位于工业使用有关，但是总体来说有这么几点：</p>
<ul>
<li>感觉没有多少社区，虽然caffe非常多公司用，但是那毕竟是第一代版本，一般公司用用还行，容易与时代脱节；</li>
<li>caffe2也没有多少亮点，官方的教程我是没有看到什么实质性的东西，后期也没有更多的example；</li>
<li>好像C++接口也不是非常友好，至少在例子上很少….一个框架推出来，不教人去用那推出来有啥意思？</li>
</ul>
<h2 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h2><p>我写文章喜欢一目了然，上面大致对比了5种框架的优缺点，下面我直接给使用者一些建议，防止大家踩坑：</p>
<ul>
<li>如果你是深度学习老鸟，你应该选择TensorFlow，但是我不得不告诉你TensorFlow在1.2版本推出来的API，在1.4版本很有可能就大改了…..</li>
<li>如果你是深度学习菜鸟，你应该选择MXNet或者PaddlePaddle，很多人会说，我曹，为什么不用Keras？？好吧，Keras当然也可以用，但是不建议一直用，还是得熟悉一下稍微底层一些的框架；</li>
<li>如果你是….如果你是小学生？高中生或者初中生，你可以用一下PaddlePaddle，因为它有中文文档，而你英文可能不太好。</li>
</ul>
<p>如果你想跟进我的更多TensorFlow项目欢迎在Github寻找我的联系方式，加入QQ群交流。</p>
<blockquote>
<p>This article was originally written by Jin Tian and first published at <a href="https://jinfagang.github.io" target="_blank" rel="noopener">https://jinfagang.github.io</a>. Reposting is welcome, but please keep this copyright notice. Thanks! Any questions can be asked via WeChat: <code>jintianiloveu</code> </p>
</blockquote>

      
    </div>
    <footer class="article-footer">
      
        <div id="donation_div"></div>

<script src="/js/vdonate.js"></script>
<script>
// Mount the donation widget into #donation_div. `Donate` is presumably
// defined by /js/vdonate.js, which is loaded right before this script.
// NOTE(review): `a` becomes a page-level global; nothing visible here reads it.
var a = new Donate({
  title: '骚年，加个好友打赏一下啊，现在连泡面都吃不起了啊', // optional: title of the donation panel
  btnText: '打赏支持', // optional: label of the donate button
  el: document.getElementById('donation_div'), // container element for the widget
  wechatImage: 'https://i.loli.net/2017/09/27/59cb048ba6838.jpeg', // WeChat image URL
  alipayImage: 'https://i.loli.net/2017/09/27/59cb049cd0951.jpeg' // Alipay image URL
});
</script>
      
      
        
	<div id="comment">
		<!-- 来必力City版安装代码 -->
		<div id="lv-container" data-id="city" data-uid="MTAyMC8zMDA5MC82NjQ1">
		<script type="text/javascript">
		   (function (doc, tag) {
		       // Load the Livere comment embed only once: when its global
		       // LivereTower is already a function, presumably the embed is
		       // already on the page, so there is nothing to do.
		       if (typeof LivereTower === 'function') {
		           return;
		       }

		       var anchor = doc.getElementsByTagName(tag)[0];
		       var loader = doc.createElement(tag);
		       loader.async = true;
		       loader.src = 'https://cdn-city.livere.com/js/embed.dist.js';

		       // Insert the new <script> immediately before the document's first one.
		       anchor.parentNode.insertBefore(loader, anchor);
		   }(document, 'script'));
		</script>
		<noscript>为正常使用来必力评论功能请激活JavaScript</noscript>
		</div>
		<!-- City版安装代码已完成 -->
	</div>



      
      
    </footer>
  </div>
  
    
<nav id="article-nav">
  
    <a href="/2017/11/03/Capsule下一代CNN深入探索/" id="article-nav-newer" class="article-nav-link-wrap">
      <strong class="article-nav-caption">上一篇</strong>
      <div class="article-nav-title">
        
          Capsule下一代CNN深入探索
        
      </div>
    </a>
  
  
    <a href="/2017/10/11/PaddlePaddle系列之三行代码从入门到精通/" id="article-nav-older" class="article-nav-link-wrap">
      <strong class="article-nav-caption">下一篇</strong>
      <div class="article-nav-title">PaddlePaddle系列之三行代码从入门到精通</div>
    </a>
  
</nav>

  
</article>

<!-- Table of Contents -->

  <aside id="toc-sidebar">
    <div id="toc" class="toc-article">
    <strong class="toc-title">文章目录</strong>
    
        <ol class="nav"><li class="nav-item nav-level-1"><a class="nav-link" href="#PaddlePaddle-TensorFlow-MXNet-Caffe2-PyTorch五大深度学习框架2017-10最新评测"><span class="nav-number">1.</span> <span class="nav-text">PaddlePaddle, TensorFlow, MXNet, Caffe2 , PyTorch五大深度学习框架2017-10最新评测</span></a><ol class="nav-child"><li class="nav-item nav-level-2"><a class="nav-link" href="#前言"><span class="nav-number">1.1.</span> <span class="nav-text">前言</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#0-五大框架概览"><span class="nav-number">1.2.</span> <span class="nav-text">0. 五大框架概览</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#1-MXNet"><span class="nav-number">1.3.</span> <span class="nav-text">1. MXNet</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#2-PaddlePaddle"><span class="nav-number">1.4.</span> <span class="nav-text">2. PaddlePaddle</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#3-TensorFlow"><span class="nav-number">1.5.</span> <span class="nav-text">3. TensorFlow</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#4-PyTorch"><span class="nav-number">1.6.</span> <span class="nav-text">4. PyTorch</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#5-Caffe2"><span class="nav-number">1.7.</span> <span class="nav-text">5. Caffe2</span></a></li><li class="nav-item nav-level-2"><a class="nav-link" href="#总结"><span class="nav-number">1.8.</span> <span class="nav-text">总结</span></a></li></ol></li></ol>
    
    </div>
  </aside>
</section>
        
      </div>
      
      <footer id="footer">
  <!-- Site footer. Links open the theme author's site in a new tab;
       rel="noopener" stops that page from controlling this one via
       window.opener. The busuanzi_value_* spans are presumably filled in by
       the busuanzi counter script loaded later in the page. -->
  <div class="container">
    <div class="row">
      <p> Powered by <a href="http://www.luoli-luoli.com/" target="_blank" rel="noopener">萝莉萝莉</a> and <a href="http://www.luoli-luoli.com/sia" target="_blank" rel="noopener">Sia</a> </p>
      <p id="copyRightEn">Copyright &copy; 2017 - 2018 Jin Tian All Rights Reserved.</p>
      <p class="busuanzi_uv">
        访客数 : <span id="busuanzi_value_site_uv"></span> |
        访问量 : <span id="busuanzi_value_site_pv"></span>
      </p>
    </div>
  </div>
</footer>


<!-- min height -->

<script>
    // Stretch #wrap to the body height and give #content whatever height is
    // left after the optional #allheader and the #footer. Every element is
    // looked up defensively, so a page that lacks one of them no longer
    // aborts this script with a TypeError.
    var wrapdiv = document.getElementById("wrap");
    var contentdiv = document.getElementById("content");
    var allheader = document.getElementById("allheader");

    (function () {
      // offsetHeight of an optional element, or 0 when it is not on the page.
      function heightOf(el) {
        return el ? el.offsetHeight : 0;
      }

      if (wrapdiv) {
        wrapdiv.style.minHeight = document.body.offsetHeight + "px";
      }
      if (contentdiv) {
        // Measure the body again: the min-height set above may have changed it.
        contentdiv.style.minHeight =
          document.body.offsetHeight - heightOf(allheader) -
          heightOf(document.getElementById("footer")) + "px";
      }
    })();
</script>
    </div>
    <!-- <nav id="mobile-nav">
  
    <a href="/" class="mobile-nav-link">Home</a>
  
    <a href="/archives" class="mobile-nav-link">Archives</a>
  
    <a href="/categories" class="mobile-nav-link">Categories</a>
  
    <a href="/tags" class="mobile-nav-link">Tags</a>
  
    <a href="/about" class="mobile-nav-link">About</a>
  
    <a href="http://luoli-luoli.com/chat" class="mobile-nav-link">Chat</a>
  
</nav> -->
    

<!-- mathjax config similar to math.stackexchange -->

<script type="text/x-mathjax-config">
  // Recognise both $...$ and \(...\) as inline-math delimiters.
  // processEscapes (a MathJax 2 tex2jax option) lets \$ be written to get a
  // literal dollar sign in page text.
  MathJax.Hub.Config({
    tex2jax: {
      inlineMath: [ ['$','$'], ["\\(","\\)"] ],
      processEscapes: true
    }
  });
</script>

<script type="text/x-mathjax-config">
    // Tell tex2jax not to scan these elements for math delimiters, so e.g. a
    // dollar sign inside a pre/code listing stays plain text.
    // Per the MathJax 2 docs, Hub.Config merges options, so this adds to the
    // tex2jax settings configured in the previous config block.
    MathJax.Hub.Config({
      tex2jax: {
        skipTags: ['script', 'noscript', 'style', 'textarea', 'pre', 'code']
      }
    });
</script>

<script type="text/x-mathjax-config">
    MathJax.Hub.Queue(function () {
        // Queued on MathJax's hub, so this runs after work already queued
        // (such as the initial typeset). It appends the 'has-jax' class to the
        // parent of every math element, presumably as a hook for site CSS.
        MathJax.Hub.getAllJax().forEach(function (jax) {
            jax.SourceElement().parentNode.className += ' has-jax';
        });
    });
</script>

<!-- MathJax 2.x loader. cdn.mathjax.org was shut down in 2017, so the library
     is served from cdnjs; the version stays on 2.x because the
     text/x-mathjax-config blocks above use the MathJax 2 Hub API. -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.9/MathJax.js?config=TeX-AMS-MML_HTMLorMML">
</script>


  <link rel="stylesheet" href="/fancybox/jquery.fancybox.css">
  <script src="/fancybox/jquery.fancybox.pack.js"></script>


<script src="/js/scripts.js"></script>




  <script src="/js/dialog.js"></script>








	<!-- CNZZ statistics script inside a hidden container; whatever markup it
	     writes is presumably meant to stay invisible. The obsolete
	     language attribute is not needed: scripts default to JavaScript. -->
	<div style="display: none;">
    <script src="https://s95.cnzz.com/z_stat.php?id=1260716016&amp;web_id=1260716016"></script>
  </div>



	<!-- busuanzi visitor counter; presumably fills the busuanzi_value_site_uv /
	     busuanzi_value_site_pv spans in the footer. The old
	     dn-lbstatics.qbox.me host was retired; busuanzi.ibruce.info is the
	     service's current host. -->
	<script async src="https://busuanzi.ibruce.info/busuanzi/2.3/busuanzi.pure.mini.js">
	</script>






  </div>

  <!-- Settings dialog. Each option is a disclosure: a toggle link plus a
       collapsed panel. The panels start collapsed (no "in" class), so every
       toggle starts with aria-expanded="false"; presumably the collapse plugin
       driven by data-toggle="collapse" updates it when a panel opens.
       Each toggle carries the id its panel's aria-labelledby refers to. -->
  <div class="modal fade" id="myModal" tabindex="-1" role="dialog" aria-labelledby="myModalLabel" aria-hidden="true" style="display: none;">
  <div class="modal-dialog">
    <div class="modal-content">
      <div class="modal-header">
        <h2 class="modal-title" id="myModalLabel">设置</h2>
      </div>
      <hr style="margin-top:0px; margin-bottom:0px; width:80%; border-top: 3px solid #000;">
      <hr style="margin-top:2px; margin-bottom:0px; width:80%; border-top: 1px solid #000;">

      <div class="modal-body">
        <div style="margin:6px;">
          <a id="headingOne" data-toggle="collapse" data-parent="#accordion" href="#collapseOne" onclick="setFontSize();" aria-expanded="false" aria-controls="collapseOne">
            正文字号大小
          </a>
        </div>
        <div id="collapseOne" class="panel-collapse collapse" role="region" aria-labelledby="headingOne">
          <div class="panel-body">
            您已调整页面字体大小
          </div>
        </div>

        <div style="margin:6px;">
          <a id="headingTwo" data-toggle="collapse" data-parent="#accordion" href="#collapseTwo" onclick="setBackground();" aria-expanded="false" aria-controls="collapseTwo">
            夜间护眼模式
          </a>
        </div>
        <div id="collapseTwo" class="panel-collapse collapse" role="region" aria-labelledby="headingTwo">
          <div class="panel-body">
            夜间模式已经开启，再次单击按钮即可关闭
          </div>
        </div>

        <div>
          <a id="headingThree" data-toggle="collapse" data-parent="#accordion" href="#collapseThree" aria-expanded="false" aria-controls="collapseThree">&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;关 于&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</a>
        </div>
        <div id="collapseThree" class="panel-collapse collapse" role="region" aria-labelledby="headingThree">
          <div class="panel-body">
            Jin Tian
          </div>
          <div class="panel-body">
            Copyright © 2018 Jintian All Rights Reserved.
          </div>
        </div>
      </div>

      <hr style="margin-top:0px; margin-bottom:0px; width:80%; border-top: 1px solid #000;">
      <hr style="margin-top:2px; margin-bottom:0px; width:80%; border-top: 3px solid #000;">
      <div class="modal-footer">
        <button type="button" class="close" data-dismiss="modal" aria-label="Close"><span aria-hidden="true">×</span></button>
      </div>
    </div>
  </div>
</div>
  
  <!-- Back-to-top control: the link itself has no text content, so
       aria-label gives it an accessible name; totop.js presumably drives it. -->
  <a id="rocket" href="#top" aria-label="返回顶部"></a>
  <script src="/js/totop.js?v=1.0.0" async></script>
  
    <a id="menu-switch"><i class="fa fa-bars fa-lg"></i></a>
  
</body>
</html>